# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import yaml
from op_build_gen import gen_build_func_str
from op_interface_gen import (
    gen_exclusive_interface_str,
    gen_op_infer_meta_str,
    vjp_interface_gen_op_list,
)
from op_member_func_gen import gen_op_get_inputs_outputs_str
from op_verify_gen import gen_verify_func_str

# =====================================
# String Template for h file code gen
# =====================================
# Wraps `{input}` in a C++ `namespace {namespace} { ... }` block
# (`{{` / `}}` render as literal braces under `str.format`).
NAMESPACE_GARD_TEMPLATE = """namespace {namespace} {{
{input}
}} // namespace {namespace}"""

# Skeleton of the generated op header. If the including file defines
# GET_OP_LIST, the macro is undefined and only `{op_declare}` is emitted;
# otherwise the full header (fixed includes, the `{input}` declarations and
# the `{declare_type_id}` type-id declarations) is emitted.
H_FILE_TEMPLATE = """#ifdef GET_OP_LIST
#undef GET_OP_LIST
{op_declare}
#else
// This file is generated by "paddle/fluid/ir/dialect/op_generator/op_gen.py"

#include <vector>

#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/operation_utils.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/fluid/ir/dialect/op_yaml_info_util.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir/interface/infermeta.h"
#include "paddle/fluid/ir/interface/vjp.h"
#include "paddle/fluid/ir/trait/inplace.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/fluid/ir/dialect/pd_manual_op.h"

{input}

{declare_type_id}
#endif
"""

# A single positional slot followed by a newline. Presumably it is filled
# with the op list used by the GET_OP_LIST branch above.
# NOTE(review): the misspelled name ("TEMPALTE") is kept as-is because other
# code may reference it.
GET_OP_LIST_TEMPALTE = """{}
"""

# Declares the explicit IR type id of one generated op class.
DECLARE_OP_TYPE_ID = """
IR_DECLARE_EXPLICIT_TYPE_ID({op_name})
"""

# Declaration of one generated op class.
# - `{interfaces}` and `{traits}` are either empty or a comma-prefixed list
#   (e.g. ",OpYamlInfoInterface"), so they append directly after `{op_name}`
#   inside the ir::Op<...> template arguments.
# - `{build_mutable_attr_is_input}` and `{build_attr_num_over_1}` hold
#   optional extra `Build` overload declarations, or empty strings.
OP_DECLARE_TEMPLATE = """
class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
 public:
  using Op::Op;
  static const char *name() {{ return "{dialect_op_name}"; }}
  {attribute_declare}
  static constexpr uint32_t attributes_num = {attribute_num};
  static OpInfoTuple GetOpInfo();
  static void Build({build_args});
  {build_mutable_attr_is_input}
  {build_attr_num_over_1}
  void Verify();
{get_inputs_and_outputs}
{exclusive_interface}
}};
"""
# Presumably substituted for `{attribute_declare}` above:
# - The `_0_` form declares a null attribute-name table.
# - The `_n_` form declares an array of `{attribute_num}` names. That array
#   matches the definition produced by OP_N_ATTRIBUTE_DEFINED_TEMPLATE for
#   the .cc file.
op_0_attribute_declare_str = (
    "static constexpr const char **attributes_name = nullptr;"
)
op_n_attribute_declare_str = (
    "static const char *attributes_name[{attribute_num}];"
)

# =====================================
# String Template for cc file code gen
# =====================================
# Skeleton of the generated op source file: fixed includes, then the
# `{input}` definitions and the `{define_type_id}` type-id definitions.
CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_generator/op_gen.py"

#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/fluid/primitive/type/static_tensor.h"
#include "paddle/ir/core/op_base.h"

{input}

{define_type_id}
"""

# Out-of-class definition of the static `attributes_name` array of an op
# class; `{attribute_names}` is the body of the brace initializer.
OP_N_ATTRIBUTE_DEFINED_TEMPLATE = """
const char *{op_name}::attributes_name[{attribute_num}] = {{ {attribute_names} }};
"""

# get op info
# Definition of `{op_name}::GetOpInfo()`, returning a tuple of
# (inputs, attributes, outputs, run_time_info).
# - `{inputs}`, `{attributes}` and `{outputs}` fill brace initializers of the
#   corresponding vectors.
# - `{infer_meta_param}`, `{kernel_func}`, `{kernel_param}` and
#   `{kernel_key_dtype}` are placed inside `{"..."}`, so the caller supplies
#   the text between the outer quotes.
# - `{inplace}` and `{view}` are placed inside plain `{...}`.
OP_INFO_TEMPLATE = """
OpInfoTuple {op_name}::GetOpInfo() {{
  std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
  std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
  std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
  paddle::dialect::OpRunTimeInfo run_time_info = OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{"{kernel_key_dtype}"}}, {{{inplace}}}, {{{view}}});

  return std::make_tuple(inputs, attributes, outputs, run_time_info);
}}
"""
# Per-entry C++ constructors for the vectors in GetOpInfo().
# `{optional}`, `{no_need_buffer}`, `{is_mutable_attribute}` and
# `{intermediate}` are substituted unquoted, so they must already be C++
# expressions such as "true" / "false".
CONSTRUCT_INPUT_INFO_TEMPLATE = """OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer}, {is_mutable_attribute})"""
CONSTRUCT_OUTPUT_INFO_TEMPLATE = (
    """OpOutputInfo("{name}", "{typename}", {optional}, {intermediate})"""
)
CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = (
    """OpAttributeInfo("{name}", "{typename}", "{data_type}")"""
)


# Defines the explicit IR type id of one generated op class; this is the
# .cc counterpart of the header's IR_DECLARE_EXPLICIT_TYPE_ID.
DEFINE_OP_TYPE_ID = """
IR_DEFINE_EXPLICIT_TYPE_ID({op_name})
"""

scalar_type_maps = {
    'int': 'ir::Int32Attribute',
    'int64_t': 'ir::Int64Attribute',
    'float': 'ir::FloatAttribute',
    'dobule': 'ir::DoubleAttribute',
    'bool': 'ir::BoolAttribute',
}

_NO_NEED_GEN_OPS = {'add_n'}


def to_phi_and_fluid_op_name(op_item):
    """Split an op_compat ``op`` entry into its phi and fluid names.

    The entry has the form ``phi_name (fluid_name)``. When there is no
    parenthesised part, the phi name is returned for both. Surrounding
    whitespace is stripped from both names.

    Returns:
        A ``(phi_name, fluid_name)`` tuple.
    """
    phi_part, open_paren, remainder = op_item.partition('(')
    phi_name = phi_part.strip()
    if not open_paren:
        return phi_name, phi_name
    # The fluid name is the text after the first '(' and before the next
    # '(' or ')'.
    fluid_part = remainder.split('(', 1)[0].split(')', 1)[0]
    return phi_name, fluid_part.strip()


def to_phi_and_fluid_grad_op_name(op_item):
    """Split an op_compat ``backward`` entry into ``[phi, fluid]`` name pairs.

    Entries are separated by ", " (a comma followed by one space), and each
    entry uses the ``phi_name (fluid_name)`` form. For example,
    ``"sum_grad (reduce_sum_grad), sum_double_grad"`` gives
    ``[['sum_grad', 'reduce_sum_grad'],
    ['sum_double_grad', 'sum_double_grad']]``.

    Returns:
        A list of two-element lists, one per backward op, in input order.
    """
    return [
        list(to_phi_and_fluid_op_name(grad_entry))
        for grad_entry in op_item.split(', ')
    ]


# =====================================
# Parse Op Compat From Yaml
# =====================================
class OpCompatParser:
    """Look up entries of an op_compat yaml file by phi op name."""

    def __init__(self, ops_compat_yaml_file):
        """Load the compat entries from ``ops_compat_yaml_file``.

        The file is decoded as UTF-8 regardless of the platform locale. An
        empty yaml document loads as ``None`` and is treated as "no entries",
        so ``get_compat`` can still iterate it.
        """
        self.ops_compat_yaml_file = ops_compat_yaml_file
        with open(self.ops_compat_yaml_file, "r", encoding="utf-8") as f:
            self.ops_compat = yaml.safe_load(f) or []

    def get_compat(self, op_name):
        """Return the first compat entry that mentions ``op_name``.

        An entry matches when its forward phi name (``op``), or any of its
        backward phi names (``backward``), equals ``op_name``.

        Returns:
            The matching entry (a dict), or ``None`` when nothing matches.
        """
        for compat in self.ops_compat:
            forward_phi_name, _ = to_phi_and_fluid_op_name(compat['op'])
            if op_name == forward_phi_name:
                return compat
            if 'backward' in compat:
                backward_names = to_phi_and_fluid_grad_op_name(
                    compat['backward']
                )
                for backward_phi_name, _ in backward_names:
                    if op_name == backward_phi_name:
                        return compat
        return None


# =====================================
# Parse Op Information From Yaml
# =====================================
class OpInfoParser:
    """Parse one op entry of the op yaml into lists consumed by code gen.

    Args:
        op_yaml_item: One op entry (a dict) from the op yaml files. The keys
            read here are 'name', 'inputs', 'outputs' and 'attrs', plus,
            when present, 'infer_meta', 'kernel', 'backward', 'inplace',
            'view' and 'custom_verify'.
        op_compat_item: The matching op_compat entry, or ``None``.

    Each ``input_*``, ``output_*`` and ``attribute_*`` list is index-aligned
    with its ``*_name_list``. Boolean flags are stored as the strings
    ``"true"`` / ``"false"`` so they can be substituted unquoted into C++
    templates.
    """

    # Prefixes of place types that must be qualified with ``phi::``.
    _PHI_PLACE_PREFIXES = (
        "Place",
        "CPUPlace",
        "GPUPlace",
        "GPUPinnedPlace",
        "XPUPlace",
        "IPUPlace",
        "CustomPlace",
    )

    def __init__(self, op_yaml_item, op_compat_item):
        self.op_yaml_item = op_yaml_item
        self.op_compat_item = op_compat_item
        # [name], or [name, name + "_"] when an inplace variant is also
        # generated.
        self.op_phi_name = self.parse_op_phi_name()
        # parse inputs
        self.input_name_list = self.parse_input_name_list()
        self.input_type_list = self.parse_input_type_list()
        self.input_optional_list = self.parse_input_optional_list()
        self.input_no_need_buffer_list = self.parse_input_no_need_buffer_list()
        self.cross_check(
            self.input_name_list, self.input_type_list, self.input_optional_list
        )

        # parse outputs
        self.output_name_list = self.parse_output_name_list()
        self.output_type_list = self.parse_output_type_list()
        self.output_size_list = self.parse_output_size_list()
        self.output_optional_list = self.parse_output_optional_list()
        self.output_intermediate_list = self.parse_output_intermediate_list()
        self.cross_check(
            self.output_name_list,
            self.output_type_list,
            self.output_optional_list,
        )

        # parse attributes
        # yaml typename -> [IR attribute type, C++ type used in Build() args]
        self.attr_types_map = {
            'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'],
            'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'],
            'Scalar(int)': ['ir::Int32Attribute', 'int'],
            'Scalar(int64_t)': ['ir::Int64Attribute', 'int64_t'],
            'Scalar(float)': ['ir::FloatAttribute', 'float'],
            'Scalar(double)': ['ir::DoubleAttribute', 'double'],
            # Misspelled key kept as an alias so existing yaml still parses;
            # it now yields the valid C++ type 'double'.
            'Scalar(dobule)': ['ir::DoubleAttribute', 'double'],
            'Scalar[]': [
                'ir::ArrayAttribute<paddle::dialect::ScalarAttribute>',
                'const std::vector<Scalar>&',
            ],
            'int': ['ir::Int32Attribute', 'int'],
            'int32_t': ['ir::Int32Attribute', 'int32_t'],
            'int64_t': ['ir::Int64Attribute', 'int64_t'],
            'long': ['ir::LongAttribute', 'long'],
            'size_t': ['ir::Size_tAttribute', 'size_t'],
            'float': ['ir::FloatAttribute', 'float'],
            'float[]': [
                'ir::ArrayAttribute<ir::FloatAttribute>',
                'const std::vector<float>&',
            ],
            'double': ['ir::DoubleAttribute', 'double'],
            'bool': ['ir::BoolAttribute', 'bool'],
            'bool[]': [
                'ir::ArrayAttribute<ir::BoolAttribute>',
                'const std::vector<bool>&',
            ],
            'str': ['ir::StrAttribute', 'const std::string&'],
            'str[]': [
                'ir::ArrayAttribute<ir::StrAttribute>',
                'const std::vector<std::string>&',
            ],
            'Place': ['paddle::dialect::PlaceAttribute', 'const Place&'],
            'DataLayout': [
                'paddle::dialect::DataLayoutAttribute',
                'DataLayout',
            ],
            'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'],
            'int64_t[]': [
                'ir::ArrayAttribute<ir::Int64Attribute>',
                'const std::vector<int64_t>&',
            ],
            'int[]': [
                'ir::ArrayAttribute<ir::Int32Attribute>',
                'const std::vector<int>&',
            ],
        }
        self.attribute_name_list = self.parse_attribute_name_list()
        self.attribute_type_list = self.parse_attribute_type_list()
        self.attribute_build_arg_type_list = (
            self.parse_attribute_build_arg_type_list()
        )
        self.attribute_data_type_list = self.parse_attribute_data_type_list()
        self.attribute_default_value_list = (
            self.parse_attribute_default_value_list()
        )
        self.cross_check(self.attribute_name_list, self.attribute_type_list)

        # parse mutable attributes (as inputs); this reads the attribute
        # lists built above.
        (
            self.mutable_attribute_name_list,
            self.mutable_attribute_type_list,
        ) = self.parse_mutable_attribute()

        # The remaining (non-mutable) attributes; this reads
        # mutable_attribute_name_list.
        (
            self.non_mutable_attribute_name_list,
            self.non_mutable_attribute_type_list,
            self.non_mutable_attribute_data_type_list,
            self.non_mutable_attribute_build_arg_type_list,
            self.non_mutable_attribute_default_value_list,
        ) = self.parse_non_nutable_attribute()

        # parse infermeta && kernel
        self.infer_meta_map = self.parse_infer_meta_map()
        self.kernel_map = self.parse_kernel_map()
        if 'infer_meta' in self.op_yaml_item:
            self.infer_meta_func = self.op_yaml_item['infer_meta']["func"]
        else:
            self.infer_meta_func = None

        # parse backward name
        self.backward_name = self.parse_backward_name()

        # parse inplace && view
        self.inplace_map = self.parse_op_inplace_info()
        self.view_map = self.parse_op_view_info()

        # parse has_custom_verify
        self.custom_verify = self.parse_custom_verify()

    def cross_check(self, name_list, type_list, optional_list=None):
        """Assert that the parallel name/type(/optional) lists line up."""
        assert len(name_list) == len(
            type_list
        ), "name list size != type list size."
        if optional_list is not None:
            assert len(type_list) == len(
                optional_list
            ), "type list size != optional list size."

    def parse_custom_verify(self):
        """Return the yaml 'custom_verify' value, or False when absent."""
        return self.op_yaml_item.get('custom_verify', False)

    def parse_op_phi_name(self):
        """Return the op names to generate.

        An op with inplace or view info whose name does not already end in
        "_" also gets an inplace variant named ``name + "_"``.
        """
        name = self.op_yaml_item['name']
        has_inplace_or_view = (
            self.parse_op_inplace_info() is not None
            or self.parse_op_view_info() is not None
        )
        if not has_inplace_or_view or name.endswith("_"):
            return [name]
        return [name, name + "_"]

    def parse_op_inplace_info(self):
        """Return the yaml 'inplace' mapping, or None when absent."""
        return self.op_yaml_item.get('inplace')

    def parse_op_view_info(self):
        """Return the yaml 'view' mapping, or None when absent."""
        return self.op_yaml_item.get('view')

    def parse_mutable_attribute(self):
        """Collect attributes that op_compat marks as mutable (fed as inputs).

        Returns:
            ``(names, types)``: two index-aligned lists ordered like
            ``attribute_name_list``. Each type is a two-element list
            ``[IR attribute type, data type]``, e.g.
            ``['paddle::dialect::ScalarAttribute', 'float']``. Names that do
            not appear in ``attribute_name_list`` are dropped.
        """
        mutable_attribute_name_list = []
        mutable_attribute_type_list = []
        compat = self.op_compat_item
        # scalar
        if compat is not None and 'scalar' in compat:
            for scalar_attr, scalar_info in compat['scalar'].items():
                if 'data_type' in scalar_info:
                    # op_compat calls one_hot's attribute "depth"; it is
                    # registered as "num_classes" here (presumably the phi
                    # attribute name).
                    if (
                        scalar_attr == "depth"
                        and self.op_phi_name[0] == "one_hot"
                    ):
                        mutable_attribute_name_list.append("num_classes")
                    else:
                        mutable_attribute_name_list.append(scalar_attr)
                    data_type = scalar_info['data_type']
                    # patch for isclose and allclose
                    if compat['op'] in ("isclose", "allclose"):
                        data_type = "float"
                    mutable_attribute_type_list.append(
                        ["paddle::dialect::ScalarAttribute", data_type]
                    )
                else:
                    # No data_type in op_compat (see eye in op_compat.yaml):
                    # fall back to the attribute's own yaml data_type.
                    mutable_attribute_name_list.append(scalar_attr)
                    mutable_attribute_type_list.append(
                        [
                            "paddle::dialect::ScalarAttribute",
                            self.attribute_data_type_list[
                                self.attribute_name_list.index(scalar_attr)
                            ],
                        ]
                    )
        # int_array
        if compat is not None and 'int_array' in compat:
            for int_array_attr, int_array_info in compat['int_array'].items():
                mutable_attribute_name_list.append(int_array_attr)
                mutable_attribute_type_list.append(
                    [
                        "paddle::dialect::IntArrayAttribute",
                        int_array_info['data_type'],
                    ]
                )

        # Re-order to follow the declaration order of the op's attributes.
        sorted_mutable_attribute_name_list = []
        sorted_mutable_attribute_type_list = []
        for attr_name in self.attribute_name_list:
            if attr_name in mutable_attribute_name_list:
                sorted_mutable_attribute_name_list.append(attr_name)
                sorted_mutable_attribute_type_list.append(
                    mutable_attribute_type_list[
                        mutable_attribute_name_list.index(attr_name)
                    ]
                )
        return (
            sorted_mutable_attribute_name_list,
            sorted_mutable_attribute_type_list,
        )

    def parse_non_nutable_attribute(self):
        """Return the attribute columns restricted to non-mutable attributes.

        Returns:
            A 5-tuple of index-aligned lists: names, IR types, data types,
            Build() argument types and default values.
        """
        names, types, data_types, build_arg_types, default_values = (
            [],
            [],
            [],
            [],
            [],
        )
        for (
            name,
            attr_type,
            data_type,
            build_arg_type,
            default_value,
        ) in zip(
            self.attribute_name_list,
            self.attribute_type_list,
            self.attribute_data_type_list,
            self.attribute_build_arg_type_list,
            self.attribute_default_value_list,
        ):
            if name in self.mutable_attribute_name_list:
                continue
            names.append(name)
            types.append(attr_type)
            data_types.append(data_type)
            build_arg_types.append(build_arg_type)
            default_values.append(default_value)
        return names, types, data_types, build_arg_types, default_values

    def parse_input_name_list(self):
        """Return the input names in yaml order."""
        return [input_info['name'] for input_info in self.op_yaml_item['inputs']]

    def parse_input_type_list(self):
        """Map each input typename to its IR type; assert it is supported."""
        input_types_map = {
            'Tensor': 'paddle::dialect::DenseTensorType',
            'Tensor[]': 'ir::VectorType<paddle::dialect::DenseTensorType>',
        }
        type_list = []
        for input_info in self.op_yaml_item['inputs']:
            assert (
                input_info['typename'] in input_types_map
            ), f"{self.op_phi_name} : Input type error: the input type only support Tensor and Tensor[], but now is {input_info['typename']}."
            type_list.append(input_types_map[input_info['typename']])
        return type_list

    def parse_input_optional_list(self):
        """Return "true"/"false" per input from its required 'optional' key."""
        return [
            "true" if input_info['optional'] else "false"
            for input_info in self.op_yaml_item['inputs']
        ]

    def parse_input_no_need_buffer_list(self):
        """Return "true"/"false" per input from its 'no_need_buffer' key."""
        return [
            "true" if input_info['no_need_buffer'] else "false"
            for input_info in self.op_yaml_item['inputs']
        ]

    def parse_output_name_list(self):
        """Return the output names in yaml order."""
        return [
            output_info['name'] for output_info in self.op_yaml_item['outputs']
        ]

    def parse_output_type_list(self):
        """Map each output typename to its IR type; assert it is supported."""
        output_type_map = {
            'Tensor': 'paddle::dialect::DenseTensorType',
            'Tensor[]': 'ir::VectorType<paddle::dialect::DenseTensorType>',
            'SelectedRows': 'paddle::dialect::SelectedRowsType',
        }
        type_list = []
        for output_info in self.op_yaml_item['outputs']:
            assert (
                output_info['typename'] in output_type_map
            ), f"{self.op_phi_name} : Output type error: the output type only support Tensor, Tensor[] and SelectedRows, but now is {output_info['typename']}."
            type_list.append(output_type_map[output_info['typename']])
        return type_list

    def parse_output_size_list(self):
        """Return each output's 'size' value, or None when absent."""
        return [
            output_info.get('size')
            for output_info in self.op_yaml_item['outputs']
        ]

    def parse_output_optional_list(self):
        """Return "true"/"false" per output; a missing 'optional' is false."""
        return [
            "true" if output_info.get('optional') else "false"
            for output_info in self.op_yaml_item['outputs']
        ]

    def parse_output_intermediate_list(self):
        """Return "true"/"false" per output; a missing 'intermediate' is false."""
        return [
            "true" if output_info.get('intermediate') else "false"
            for output_info in self.op_yaml_item['outputs']
        ]

    def parse_attribute_name_list(self):
        """Return the attribute names in yaml order."""
        return [
            attribute_info['name']
            for attribute_info in self.op_yaml_item['attrs']
        ]

    def parse_attribute_build_arg_type_list(self):
        """Return the C++ parameter type used for each attribute in Build()."""
        type_list = []
        for attribute_info in self.op_yaml_item['attrs']:
            typename = attribute_info['typename']
            assert (
                typename in self.attr_types_map
            ), f"{self.op_phi_name} : Attr type error: unsupported attribute type {typename}."

            # Scalar & IntArray has data_type: an explicit yaml data_type
            # overrides the default build argument type. The checks run in
            # sequence, so the IntArray check sees the Scalar override.
            temp_type = self.attr_types_map[typename][1]
            if 'Scalar' in temp_type and 'data_type' in attribute_info:
                temp_type = attribute_info['data_type']
            if 'IntArray' in temp_type and 'data_type' in attribute_info:
                temp_type = "const " + attribute_info['data_type'] + "&"
            type_list.append(self.get_phi_dtype_name(temp_type))
        return type_list

    def parse_attribute_type_list(self):
        """Return the IR attribute type for each attribute."""
        type_list = []
        for attribute_info in self.op_yaml_item['attrs']:
            typename = attribute_info['typename']
            assert (
                typename in self.attr_types_map
            ), f"{self.op_phi_name} : Attr type error: unsupported attribute type {typename}."
            type_list.append(self.attr_types_map[typename][0])
        return type_list

    def parse_attribute_data_type_list(self):
        """Return each attribute's yaml 'data_type', or "" when absent."""
        return [
            attribute_info.get('data_type', "")
            for attribute_info in self.op_yaml_item['attrs']
        ]

    def parse_attribute_default_value_list(self):
        """Return each attribute's phi-qualified default value, or None."""
        return [
            self.get_phi_dtype_name(attribute_info['default_value'])
            if 'default_value' in attribute_info
            else None
            for attribute_info in self.op_yaml_item['attrs']
        ]

    def parse_infer_meta_map(self):
        """Return the yaml 'infer_meta' mapping, or None when absent."""
        return self.op_yaml_item.get('infer_meta')

    def parse_kernel_map(self):
        """Return the yaml 'kernel' mapping, or None when absent."""
        return self.op_yaml_item.get('kernel')

    def parse_backward_name(self):
        """Return the yaml 'backward' value, or None when absent."""
        return self.op_yaml_item.get('backward')

    def get_phi_dtype_name(self, name):
        """Qualify the phi types appearing in ``name`` with ``phi::``.

        Scalar, IntArray, DataLayout and DataType are qualified wherever
        they occur in the string. Place types are qualified only when the
        string starts with one of them.
        """
        for phi_type in ('Scalar', 'IntArray', 'DataLayout', 'DataType'):
            name = name.replace(phi_type, 'phi::' + phi_type)
        if name.startswith(self._PHI_PLACE_PREFIXES):
            return "phi::" + name
        return name


def to_pascal_case(s):
    """Convert a snake_case op name to PascalCase.

    Each "_"-separated word is capitalized (so the rest of each word is
    lower-cased) and the words are joined. A trailing "_" is preserved, so
    inplace op names keep their marker: "relu_" -> "Relu_" and
    "add_n" -> "AddN". An empty string returns "" instead of raising.
    """
    pascal = "".join(word.capitalize() for word in s.split("_"))
    return pascal + "_" if s.endswith("_") else pascal


def OpGenerator(
    op_yaml_files,
    op_compat_yaml_file,
    namespaces,
    dialect_name,
    op_def_h_file,
    op_def_cc_file,
):
    # (1) Prepare: Delete existing old files: pd_op.h.tmp, pd_op.cc.tmp
    if os.path.exists(op_def_h_file):
        os.remove(op_def_h_file)
    if os.path.exists(op_def_cc_file):
        os.remove(op_def_cc_file)

    # (2) Prepare: Get all op item in all op_yaml_files
    op_compat_parser = OpCompatParser(op_compat_yaml_file)

    op_yaml_items = []
    for yaml_file in op_yaml_files:
        with open(yaml_file, "r") as f:
            ops = yaml.safe_load(f)
            op_yaml_items = op_yaml_items + ops
    op_info_items = {}
    for op in op_yaml_items:
        op_info_items[op['name']] = OpInfoParser(
            op, op_compat_parser.get_compat(op['name'])
        )
    # (3) CodeGen: Traverse op_info_items and generate
    ops_name_list = []  # all op class name store in this list
    ops_declare_list = []  # all op class declare store in this list
    ops_defined_list = []  # all op class defined store in this list
    for key, op_info in op_info_items.items():
        # get op inputs info
        op_input_name_list = op_info.input_name_list
        op_input_type_list = op_info.input_type_list
        op_input_optional_list = op_info.input_optional_list
        op_input_no_need_buffer_list = op_info.input_no_need_buffer_list
        # get op outputs info
        op_output_name_list = op_info.output_name_list
        op_output_type_list = op_info.output_type_list
        op_output_size_list = op_info.output_size_list
        op_output_optional_list = op_info.output_optional_list
        op_output_intermediate_list = op_info.output_intermediate_list
        # get op mutable attribute
        op_mutable_attribute_name_list = op_info.mutable_attribute_name_list
        op_mutable_attribute_type_list = op_info.mutable_attribute_type_list
        # get op attribute
        op_attribute_name_list = op_info.attribute_name_list
        op_attribute_type_list = op_info.attribute_type_list
        op_attribute_data_type_list = op_info.attribute_data_type_list
        op_attribute_build_arg_type_list = op_info.attribute_build_arg_type_list
        op_attribute_default_value_list = op_info.attribute_default_value_list
        op_non_mutable_attribute_name_list = (
            op_info.non_mutable_attribute_name_list
        )
        op_non_mutable_attribute_type_list = (
            op_info.non_mutable_attribute_type_list
        )
        op_non_mutable_attribute_data_type_list = (
            op_info.non_mutable_attribute_data_type_list
        )
        op_non_mutable_attribute_build_arg_type_list = (
            op_info.non_mutable_attribute_build_arg_type_list
        )
        op_non_mutable_attribute_default_value_list = (
            op_info.non_mutable_attribute_default_value_list
        )

        # others
        op_infer_meta_map = op_info.infer_meta_map
        op_kernel_map = op_info.kernel_map
        op_inplace_map = op_info.inplace_map
        op_view_map = op_info.view_map
        op_interfaces = ["OpYamlInfoInterface"]
        op_traits = []

        if op_info.infer_meta_func:
            op_interfaces += ["InferMetaInterface"]

        if (
            op_info.backward_name
            and op_info.op_phi_name[0] in vjp_interface_gen_op_list
        ):
            op_interfaces += ["VjpInterface"]
        exclusive_interface_str = gen_exclusive_interface_str(op_info)

        # If op has inplace info, we will generate inplace op and non-inplace op.
        for op_name in op_info.op_phi_name:
            if op_name in _NO_NEED_GEN_OPS:
                continue
            op_class_name = to_pascal_case(op_name) + "Op"
            op_dialect_name = dialect_name + "." + op_name

            # =================================== #
            #    gen interface/trait list str     #
            # =================================== #
            op_interfaces_str = ""
            if len(op_interfaces) > 0:
                op_interfaces_str = "," + ",".join(op_interfaces)

            if op_name[-1] == "_":
                op_traits += ["InplaceTrait"]

            op_traits_str = ""
            if len(op_traits) > 0:
                op_traits_str = "," + ",".join(op_traits)

            # =================================== #
            #  gen get input/output methods str   #
            # =================================== #
            op_get_inputs_outputs_str = gen_op_get_inputs_outputs_str(
                op_input_name_list,
                op_mutable_attribute_name_list,
                op_output_name_list,
            )

            # =================================== #
            #         gen Build methods str       #
            # =================================== #
            build_args_with_muta_attr_not_input_for_declare = ""
            build_func_with_muta_attr_not_input = ""
            build_mutable_attr_is_input = ""
            build_attr_num_over_1 = ""
            build_func_with_attr_is_map = ""
            build_func_with_muta_attr_is_input = ""

            if op_infer_meta_map is not None:
                (
                    build_args_with_muta_attr_not_input_for_declare,
                    build_func_with_muta_attr_not_input,
                ) = gen_build_func_str(
                    op_class_name,
                    op_input_name_list,
                    op_input_type_list,
                    op_attribute_name_list,
                    op_attribute_type_list,
                    op_attribute_build_arg_type_list,
                    op_attribute_default_value_list,
                    op_mutable_attribute_name_list,
                    op_mutable_attribute_type_list,
                    op_non_mutable_attribute_name_list,
                    op_non_mutable_attribute_type_list,
                    op_non_mutable_attribute_build_arg_type_list,
                    op_non_mutable_attribute_default_value_list,
                    op_output_name_list,
                    op_output_type_list,
                    op_output_size_list,
                    op_infer_meta_map,
                    muta_attr_is_input=False,
                )
                if len(op_attribute_name_list) > 1:
                    (
                        build_args_with_attr_is_map_for_declare,
                        build_func_with_attr_is_map,
                    ) = gen_build_func_str(
                        op_class_name,
                        op_input_name_list,
                        op_input_type_list,
                        op_attribute_name_list,
                        op_attribute_type_list,
                        op_attribute_build_arg_type_list,
                        op_attribute_default_value_list,
                        op_mutable_attribute_name_list,
                        op_mutable_attribute_type_list,
                        op_non_mutable_attribute_name_list,
                        op_non_mutable_attribute_type_list,
                        op_non_mutable_attribute_build_arg_type_list,
                        op_non_mutable_attribute_default_value_list,
                        op_output_name_list,
                        op_output_type_list,
                        op_output_size_list,
                        op_infer_meta_map,
                        muta_attr_is_input=False,
                        attr_args_is_map=True,
                    )
                    build_attr_num_over_1 = (
                        "static void Build({build_args});".format(
                            build_args=build_args_with_attr_is_map_for_declare
                        )
                    )

                if len(op_mutable_attribute_name_list) > 0:
                    (
                        build_args_with_muta_attr_is_input_for_declare,
                        build_func_with_muta_attr_is_input,
                    ) = gen_build_func_str(
                        op_class_name,
                        op_input_name_list,
                        op_input_type_list,
                        op_attribute_name_list,
                        op_attribute_type_list,
                        op_attribute_build_arg_type_list,
                        op_attribute_default_value_list,
                        op_mutable_attribute_name_list,
                        op_mutable_attribute_type_list,
                        op_non_mutable_attribute_name_list,
                        op_non_mutable_attribute_type_list,
                        op_non_mutable_attribute_build_arg_type_list,
                        op_non_mutable_attribute_default_value_list,
                        op_output_name_list,
                        op_output_type_list,
                        op_output_size_list,
                        op_infer_meta_map,
                        muta_attr_is_input=True,
                    )

                    build_mutable_attr_is_input = "static void Build({build_args});".format(
                        build_args=build_args_with_muta_attr_is_input_for_declare
                    )

            # gen op_declare_str/op_defined_str
            if len(op_non_mutable_attribute_name_list) == 0:
                op_declare_str = OP_DECLARE_TEMPLATE.format(
                    op_name=op_class_name,
                    dialect_op_name=op_dialect_name,
                    interfaces=op_interfaces_str,
                    traits=op_traits_str,
                    attribute_declare=op_0_attribute_declare_str,
                    attribute_num=0,
                    build_args=build_args_with_muta_attr_not_input_for_declare,
                    build_mutable_attr_is_input=build_mutable_attr_is_input,
                    build_attr_num_over_1=build_attr_num_over_1,
                    get_inputs_and_outputs=op_get_inputs_outputs_str,
                    exclusive_interface=exclusive_interface_str,
                )
                op_defined_str = ""
            else:
                op_declare_str = OP_DECLARE_TEMPLATE.format(
                    op_name=op_class_name,
                    dialect_op_name=op_dialect_name,
                    interfaces=op_interfaces_str,
                    traits=op_traits_str,
                    attribute_declare=op_n_attribute_declare_str.format(
                        attribute_num=len(op_non_mutable_attribute_name_list)
                    ),
                    attribute_num=len(op_non_mutable_attribute_name_list),
                    build_args=build_args_with_muta_attr_not_input_for_declare,
                    build_mutable_attr_is_input=build_mutable_attr_is_input,
                    build_attr_num_over_1=build_attr_num_over_1,
                    get_inputs_and_outputs=op_get_inputs_outputs_str,
                    exclusive_interface=exclusive_interface_str,
                )
                attribute_names_str = (
                    '"' + '", "'.join(op_non_mutable_attribute_name_list) + '"'
                )
                op_defined_str = OP_N_ATTRIBUTE_DEFINED_TEMPLATE.format(
                    op_name=op_class_name,
                    attribute_num=len(op_non_mutable_attribute_name_list),
                    attribute_names=attribute_names_str,
                )

            # =================================== #
            #         gen GetOpInfo func str      #
            # =================================== #
            # generate get op info funciton: inputs
            input_info_list = []
            for idx in range(len(op_input_name_list)):
                input_info_list.append(
                    CONSTRUCT_INPUT_INFO_TEMPLATE.format(
                        name=op_input_name_list[idx],
                        typename=op_input_type_list[idx],
                        optional=op_input_optional_list[idx],
                        no_need_buffer=op_input_no_need_buffer_list[idx],
                        is_mutable_attribute='false',
                    )
                )
            for idx in range(len(op_mutable_attribute_name_list)):
                input_info_list.append(
                    CONSTRUCT_INPUT_INFO_TEMPLATE.format(
                        name=op_mutable_attribute_name_list[idx],
                        typename=op_mutable_attribute_type_list[idx][0],
                        optional='false',
                        no_need_buffer='false',
                        is_mutable_attribute='true',
                    )
                )
            if len(input_info_list) > 0:
                inputs_info_str = ", ".join(input_info_list)
            else:
                inputs_info_str = ""
            # generate get op info funciton: outputs
            outputs_info_str = ""
            if len(op_output_name_list) > 0:
                output_info_list = []
                for idx in range(len(op_output_name_list)):
                    output_info_list.append(
                        CONSTRUCT_OUTPUT_INFO_TEMPLATE.format(
                            name=op_output_name_list[idx],
                            typename=op_output_type_list[idx],
                            optional=op_output_optional_list[idx],
                            intermediate=op_output_intermediate_list[idx],
                        )
                    )
                outputs_info_str = ", ".join(output_info_list)
            # generate get op info funciton: attributes
            attribute_info_str = ""
            if len(op_non_mutable_attribute_name_list) > 0:
                attribute_info_list = []
                for idx in range(len(op_non_mutable_attribute_name_list)):
                    attribute_info_list.append(
                        CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format(
                            name=op_non_mutable_attribute_name_list[idx],
                            typename=op_non_mutable_attribute_type_list[idx],
                            data_type=op_non_mutable_attribute_data_type_list[
                                idx
                            ],
                        )
                    )
                attribute_info_str = ", ".join(attribute_info_list)
            # generate runtiem info
            infer_meta_func_str = ""
            infer_meta_param_str = ""
            if op_infer_meta_map is not None:
                infer_meta_func_str = op_infer_meta_map['func']
                infer_meta_param_str = '", "'.join(op_infer_meta_map['param'])

            kernel_func_str = ""
            kernel_param_str = ""
            kernel_key_dtype = ""
            if op_kernel_map is not None:
                kernel_func_str = '", "'.join(op_kernel_map['func'])
                kernel_param_str = '", "'.join(op_kernel_map['param'])
                if 'data_type' in op_kernel_map and op_kernel_map['data_type']:
                    kernel_key_dtype = '", "'.join(
                        op_kernel_map['data_type']['candidates']
                    )

            inplace_str = ""
            view_str = ""
            if op_name[-1] == "_":
                if op_inplace_map is not None:
                    for key, value in op_inplace_map.items():
                        inplace_str += '{"' + key + '", "' + value + '"},'
                    inplace_str = inplace_str[:-1]
                if op_view_map is not None:
                    for key, value in op_view_map.items():
                        view_str += '{"' + key + '", "' + value + '"},'
                    view_str = view_str[:-1]

            op_info_func_str = OP_INFO_TEMPLATE.format(
                op_name=op_class_name,
                inputs=inputs_info_str,
                attributes=attribute_info_str,
                outputs=outputs_info_str,
                infer_meta_func=infer_meta_func_str,
                infer_meta_param=infer_meta_param_str,
                kernel_func=kernel_func_str,
                kernel_param=kernel_param_str,
                kernel_key_dtype=kernel_key_dtype,
                inplace=inplace_str,
                view=view_str,
            )

            # generate op verify function str
            op_verify_str = ''
            if not op_info.custom_verify:
                op_verify_str = gen_verify_func_str(
                    op_class_name,
                    op_input_type_list,
                    op_input_optional_list,
                    op_mutable_attribute_name_list,
                    op_mutable_attribute_type_list,
                    op_non_mutable_attribute_name_list,
                    op_non_mutable_attribute_type_list,
                    op_output_type_list,
                    op_output_optional_list,
                )

            op_infer_meta_str = gen_op_infer_meta_str(op_info, op_class_name)

            # =================================== #
            #         gen Vjp func str      #
            # =================================== #

            # generate op vjp function str
            op_vjp_str = ''

            # TODO(chenzhiyang) add vjp gen code
            # if op_info.backward_name and op_info.op_phi_name[0] in vjp_interface_gen_op_list:
            #     op_vjp_str = gen_op_vjp_str(op_class_name,
            #                                 op_info.backward_name,
            #                                 op_name,
            #                                 op_info_items[op_info.op_phi_name[0]],
            #                                 op_info_items[op_info.backward_name])

            ops_name_list.append(op_class_name)
            ops_declare_list.append(op_declare_str)
            ops_defined_list.append(op_defined_str)
            ops_defined_list.append(op_info_func_str)
            ops_defined_list.append(build_func_with_muta_attr_not_input)
            ops_defined_list.append(build_func_with_attr_is_map)
            if len(op_mutable_attribute_name_list) > 0:
                ops_defined_list.append(build_func_with_muta_attr_is_input)
            ops_defined_list.append(op_verify_str)
            ops_defined_list.append(op_infer_meta_str)
            ops_defined_list.append(op_vjp_str)

    # (4) Generate head file str
    op_namespaces_prev = ""
    for name in namespaces:
        op_namespaces_prev += name + "::"
    ops_name_with_namespace_list = []
    for name in ops_name_list:
        ops_name_with_namespace_list.append(op_namespaces_prev + name)
    op_list_str = GET_OP_LIST_TEMPALTE.format(
        ", ".join(ops_name_with_namespace_list)
    )  # Add GET_OP_LIST

    declare_type_id_str = ""
    for op in ops_name_with_namespace_list:
        declare_type_id_str += DECLARE_OP_TYPE_ID.format(op_name=op)

    head_file_str = ""
    head_file_str += "".join(ops_declare_list)  # Add op class
    for name in reversed(namespaces):
        head_file_str = NAMESPACE_GARD_TEMPLATE.format(
            namespace=name, input=head_file_str
        )  # Add namespaces
    head_file_str = H_FILE_TEMPLATE.format(
        op_declare=op_list_str,
        input=head_file_str,
        declare_type_id=declare_type_id_str,
    )  # Add head

    # (5) Generate source file str
    source_file_str = "".join(ops_defined_list)  # Add op define
    for name in reversed(namespaces):
        source_file_str = NAMESPACE_GARD_TEMPLATE.format(
            namespace=name, input=source_file_str
        )  # Add namespaces

    define_type_id_str = ""
    for op in ops_name_with_namespace_list:
        define_type_id_str += DEFINE_OP_TYPE_ID.format(op_name=op)

    source_file_str = CC_FILE_TEMPLATE.format(
        h_file=op_def_h_file[:-4],
        input=source_file_str,
        define_type_id=define_type_id_str,
    )  # Add head

    # (5) Generate pd_op.h.tmp, pd_op.cc.tmp
    with open(op_def_h_file, 'a') as f:
        f.write(head_file_str)
    with open(op_def_cc_file, 'a') as f:
        f.write(source_file_str)


# =====================================
# Script parameter parsing
# =====================================
def ParseArguments(argv=None):
    """Parse the command-line options of the dialect op code generator.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Defaults to ``None``, which reads the real
            command line (the behavior of the original zero-argument call).

    Returns:
        An ``argparse.Namespace`` with the string attributes
        ``op_yaml_files``, ``op_compat_yaml_file``, ``namespaces``,
        ``dialect_name``, ``op_def_h_file`` and ``op_def_cc_file``
        (optional ones are ``None`` when omitted).

    Raises:
        SystemExit: via argparse when a required option is missing.
    """
    parser = argparse.ArgumentParser(
        description='Generate Dialect OP Definition Files By Yaml'
    )
    # The script calls .split(",") on this value unconditionally, so reject
    # a missing value up front with a usage error instead of crashing later.
    parser.add_argument(
        '--op_yaml_files',
        type=str,
        required=True,
        help='comma-separated list of op yaml files to generate from',
    )
    parser.add_argument(
        '--op_compat_yaml_file',
        type=str,
        help='op compat yaml file forwarded to the generator',
    )
    parser.add_argument(
        '--namespaces',
        type=str,
        help='comma-separated C++ namespaces, outermost first',
    )
    parser.add_argument(
        '--dialect_name',
        type=str,
        help='dialect name forwarded to the generator',
    )
    # Both output paths are sliced and opened for append by the generator,
    # so they cannot be left unset.
    parser.add_argument(
        '--op_def_h_file',
        type=str,
        required=True,
        help='header file the generated declarations are appended to',
    )
    parser.add_argument(
        '--op_def_cc_file',
        type=str,
        required=True,
        help='source file the generated definitions are appended to',
    )
    return parser.parse_args(argv)


# =====================================
# Main
# =====================================
if __name__ == "__main__":
    # Read the command line once; every generator input is taken from it.
    args = ParseArguments()

    # List-valued options arrive as comma-separated strings.
    op_yaml_files = args.op_yaml_files.split(",")

    # Scalar options are forwarded to the generator unchanged.
    op_compat_yaml_file = args.op_compat_yaml_file

    # An omitted --namespaces yields an empty list, so the generated code
    # is not wrapped in any namespace.
    namespaces = (
        [] if args.namespaces is None else args.namespaces.split(",")
    )

    dialect_name = args.dialect_name
    op_def_h_file, op_def_cc_file = args.op_def_h_file, args.op_def_cc_file

    # Run the code generation with the collected options.
    OpGenerator(
        op_yaml_files,
        op_compat_yaml_file,
        namespaces,
        dialect_name,
        op_def_h_file,
        op_def_cc_file,
    )
